'''
Description: Evaluate PaddleOCR text-recognition accuracy on line images by edit distance against gt.txt.
Author: KouXichao
Date: 2020-11-18 19:39:48
LastEditors: KouXichao
LastEditTime: 2020-11-26 14:31:36
FilePath: \PaddleOCR\eval_acc_lineimg.py
Github: https://github.com/kouxichao
'''

import os
import argparse
import numpy as np
import cv2
import Levenshtein
import math
import shutil  
import re
from zhon.hanzi import punctuation 
import string

# Import the text-recognition modules
import tools.infer.predict_rec as predict_rec
import tools.infer.utility as utility
 
# A line counts as "hard" when its character error rate exceeds this value.
HARD_LINE_ERROR_RATE = 0.3
# Directory that receives a copy of every hard line image.
HARD_LINE_DIR = "./hard_line_img/"


def parse_gt_line(line):
    """Split one ``gt.txt`` record into ``(image_name, text)``.

    Args:
        line: A raw line made of an image name, a TAB and the ground-truth
            text, optionally terminated by a newline.

    Returns:
        An ``(image_name, text)`` tuple, or ``None`` when the line is blank.

    Raises:
        ValueError: If a non-blank line contains no TAB separator.
    """
    # Strip the newline *before* splitting so it never leaks into the text.
    record = line.rstrip('\n')
    if not record.strip():
        return None
    fields = record.split('\t')
    if len(fields) < 2:
        raise ValueError(
            "malformed gt.txt line (expected '<image name><TAB><text>'): {!r}".format(line))
    return fields[0], fields[1]


def error_rate(distance, length):
    """Return ``distance / length`` without failing on an empty reference.

    Args:
        distance: Edit distance between reference and prediction.
        length: Number of characters in the reference.

    Returns:
        The ratio as a float. For an empty reference the result is ``0.0``
        when the distance is zero and ``inf`` otherwise.
    """
    if length == 0:
        return float('inf') if distance else 0.0
    return distance / float(length)


def accuracy(distance, total_chars):
    """Return ``1 - distance / total_chars``, or NaN when there are no chars.

    Args:
        distance: Accumulated edit distance.
        total_chars: Accumulated number of reference characters.

    Returns:
        The accuracy as a float; ``nan`` if ``total_chars`` is zero because
        the value is undefined in that case.
    """
    if total_chars == 0:
        return float('nan')
    return 1 - distance / float(total_chars)


def load_ground_truth(data_path):
    """Read ``gt.txt`` from *data_path* and load every referenced image.

    Args:
        data_path: Directory containing ``gt.txt`` and the line images.

    Returns:
        A ``(image_paths, texts, images)`` tuple of parallel lists. Images
        are converted from BGR to RGB after loading.

    Raises:
        ValueError: If a record in ``gt.txt`` is malformed.
        IOError: If an image cannot be read by OpenCV.
    """
    image_paths, texts, images = [], [], []
    with open(os.path.join(data_path, "gt.txt"), "r") as f:
        for line in f:
            record = parse_gt_line(line)
            if record is None:
                continue
            image_name, text = record
            image_path = os.path.join(data_path, image_name)

            # cv2.imread signals failure by returning None instead of raising.
            image = cv2.imread(image_path, cv2.IMREAD_COLOR)
            if image is None:
                raise IOError("cannot read line image: {}".format(image_path))

            image_paths.append(image_path)
            texts.append(text)
            images.append(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    return image_paths, texts, images


def main():
    """Run recognition on all line images and print edit-distance accuracy.

    Four variants are reported: raw text, whitespace removed, punctuation
    removed, and both removed. Images of lines whose raw error rate exceeds
    ``HARD_LINE_ERROR_RATE`` are copied to ``HARD_LINE_DIR``.
    """
    args = utility.parse_args()
    line_image_name_list, textOfLine, imagesOfLine = load_ground_truth(args.image_dir)

    # Recognition
    print("\033[31mRecognization start(total number:{}) !!!\033[0m".format(len(imagesOfLine)))
    text_recognizer = predict_rec.TextRecognizer(args)
    rec_res, elapse = text_recognizer(imagesOfLine)
    print("\033[31mline num: {}, rec_res num  : {}, elapse : {}\033[0m".format(len(imagesOfLine), len(rec_res), elapse))

    # zip() below would silently truncate on a mismatch, so fail loudly here.
    if len(imagesOfLine) != len(rec_res):
        raise RuntimeError("recognizer returned {} results for {} images".format(
            len(rec_res), len(imagesOfLine)))

    # Calculate edit distance and accuracy of OCR
    dist = 0
    dist_without_space = 0
    dist_without_punc = 0
    dist_without_spacepunc = 0
    total_char_num = 0
    # A translation table deletes every punctuation character literally; no
    # regex escaping issues (e.g. the backslash) can arise.
    punc_table = str.maketrans('', '', punctuation + string.punctuation)

    for gt, res, imgname, img in zip(textOfLine, rec_res, line_image_name_list, imagesOfLine):
        # Accuracy of the raw output
        pre = res[0]
        line_dist = Levenshtein.distance(gt, pre)
        dist += line_dist
        total_char_num += len(gt)
        print("line image: ", imgname)
        print("\t\033[33morig text: {}\033[0m\t\033[34mpred text: {}\033[0m".format(gt, pre))
        print()

        # Keep a copy of badly recognised lines for visual inspection
        if error_rate(line_dist, len(gt)) > HARD_LINE_ERROR_RATE:
            os.makedirs(HARD_LINE_DIR, exist_ok=True)
            # Images are held as RGB; cv2.imwrite expects BGR channel order.
            cv2.imwrite(os.path.join(HARD_LINE_DIR, os.path.basename(imgname)),
                        cv2.cvtColor(img, cv2.COLOR_RGB2BGR))

        # Accuracy with whitespace removed
        gt_without_space = ''.join(gt.split())
        pre_without_space = ''.join(pre.split())
        dist_without_space += Levenshtein.distance(gt_without_space, pre_without_space)
        print("\t\033[33morig text(without space): {}\033[0m\n\t\033[34mpred text(without space): {}\033[0m".format(gt_without_space, pre_without_space))
        print()

        # Accuracy with punctuation removed
        gt_without_punc = gt.translate(punc_table)
        pre_without_punc = pre.translate(punc_table)
        dist_without_punc += Levenshtein.distance(gt_without_punc, pre_without_punc)
        print("\t\033[33morig text(without punc): {}\033[0m\t\033[34mpred text(without punc): {}\033[0m".format(gt_without_punc, pre_without_punc))
        print()

        # Accuracy with both punctuation and whitespace removed
        gt_without_spacepunc = ''.join(gt_without_punc.split())
        pre_without_spacepunc = ''.join(pre_without_punc.split())
        dist_without_spacepunc += Levenshtein.distance(gt_without_spacepunc, pre_without_spacepunc)
        print("\t\033[33morig text(without space_punc): {}\033[0m\n\t\033[34mpred text(without space_punc): {}\033[0m".format(gt_without_spacepunc, pre_without_spacepunc))

    # NOTE: every variant is normalised by the raw ground-truth character
    # count (whitespace and punctuation included), as the report labels state.
    print("***********Edit Distance: {}, Total Char Num: {}, Accuracy: {}**********".format(dist, total_char_num, accuracy(dist, total_char_num)))
    print("***********Edit Distance(no space): {}, Total Char Num: {}, Accuracy: {}**********".format(dist_without_space, total_char_num, accuracy(dist_without_space, total_char_num)))
    print("***********Edit Distance(no punc): {}, Total Char Num: {}, Accuracy: {}**********".format(dist_without_punc, total_char_num, accuracy(dist_without_punc, total_char_num)))
    print("***********Edit Distance(no space_punc): {}, Total Char Num: {}, Accuracy: {}**********".format(dist_without_spacepunc, total_char_num, accuracy(dist_without_spacepunc, total_char_num)))


if __name__ == '__main__':
    main()

